import torch
import torch.nn as nn
from Network.network_utils import pytorch_model, initialize_optimizer
import numpy as np

class InferenceModule(nn.Module):
    '''
    Pytorch specific operations (device moves, optimizer construction, target
    and mask bookkeeping) localized in this superclass.

    Several methods read attributes this class never creates: model,
    inter_model, name, mp, target and forward_dist. Presumably subclasses
    define them -- confirm against the concrete modules.
    '''
    def __init__(self, args, extractor):
        '''
        args: config namespace, reads args.torch.cuda, args.torch.gpu and
            args.network.optimizer
        extractor: stored as-is, get_omit and _target_dists pass it on / call
            its get_index
        '''
        super().__init__()
        self.iscuda = args.torch.cuda
        self.device = args.torch.gpu
        self.test = False # puts the module in testing mode
        self.extractor = extractor
        self.optimizer_args = args.network.optimizer

    def _build_optimizer(self, model, alt=False):
        # optimizer over a single model, None when there is no model.
        # the learning rate is only read when a model exists (alt picks alt_lr)
        if model is None: return None
        lr = self.optimizer_args.alt_lr if alt else self.optimizer_args.lr
        return initialize_optimizer(model, self.optimizer_args, lr)

    def _build_joint_optimizer(self, model, inter_model):
        # a single optimizer over the parameters of both models, None unless
        # both models exist (otherwise there is nothing to combine)
        if model is None or inter_model is None: return None
        joint_params = list(model.parameters()) + list(inter_model.parameters())
        return initialize_optimizer(joint_params, self.optimizer_args, self.optimizer_args.lr, params=True)

    def _rebuild_optimizers(self):
        # re-creates only the optimizers that already exist and are not None,
        # using whatever model / inter_model are currently assigned.
        # NOTE(review): presumably initialize_optimizer returns a fresh optimizer,
        # so any accumulated optimizer state is dropped here -- confirm
        model = getattr(self, "model", None)
        inter_model = getattr(self, "inter_model", None)
        if getattr(self, "optimizer", None) is not None:
            self.optimizer = self._build_optimizer(model)
        if getattr(self, "inter_optimizer", None) is not None:
            self.inter_optimizer = self._build_optimizer(inter_model, alt=True)
        if getattr(self, "both_optimizer", None) is not None:
            self.both_optimizer = self._build_joint_optimizer(model, inter_model)

    def init_optimizer(self, args):
        '''
        initializes all the optimizers here, from args.network.optimizer
        self.optimizer is always set (None when there is no model)
        self.inter_optimizer / self.both_optimizer are only set when the module
            has an inter_model attribute
        '''
        self.optimizer_args = args.network.optimizer
        # getattr keeps this consistent with the hasattr guards used elsewhere:
        # a module without a model gets optimizer None instead of raising
        model = getattr(self, "model", None)
        self.optimizer = self._build_optimizer(model)
        if hasattr(self, "inter_model"):
            self.inter_optimizer = self._build_optimizer(self.inter_model, alt=True)
            self.both_optimizer = self._build_joint_optimizer(model, self.inter_model)

    def return_model_optimizer(self, inter=False, both=False):
        '''
        returns a (module, optimizer) pair
        both: this module itself with both_optimizer (takes priority over inter)
        inter: inter_model with inter_optimizer
        otherwise: model with optimizer
        '''
        if both: return self, self.both_optimizer
        if inter: return self.inter_model, self.inter_optimizer
        return self.model, self.optimizer

    def assign_from_model(self, model, inter_model=None):
        '''
        replaces model (and inter_model, if one is given) only where the module
        already holds a non-None one, then rebuilds the existing optimizers
        '''
        if getattr(self, "model", None) is not None: self.model = model
        if getattr(self, "inter_model", None) is not None and inter_model is not None: self.inter_model = inter_model
        self._rebuild_optimizers()

    def cuda(self, device=None):
        '''
        moves the module to the gpu, rebuilds the existing optimizers and
        records the device when one is given. returns self
        '''
        # forward the device so everything registered on this module ends up on
        # the same gpu as model / inter_model
        super().cuda(device)
        self.iscuda = True
        model = getattr(self, "model", None)
        inter_model = getattr(self, "inter_model", None)
        if model is not None: model.cuda(device=device)
        if inter_model is not None: inter_model.cuda(device=device)
        self._rebuild_optimizers()
        if device is not None: self.device = device
        return self

    def cpu(self):
        '''
        moves the module to the cpu and rebuilds the existing optimizers.
        returns self
        '''
        super().cpu()
        model = getattr(self, "model", None)
        inter_model = getattr(self, "inter_model", None)
        # the models are expected to carry their own iscuda flag
        if model is not None and model.iscuda: model.cpu()
        if inter_model is not None and inter_model.iscuda: inter_model.cpu()
        self._rebuild_optimizers()
        self.iscuda = False
        return self

    def get_output(self, model_params, name, extractor, batch):
        '''
        picks the prediction target out of batch: batch.target_diff when
        model_params.predict_dynamics is set, batch.next_target otherwise.
        name "all" returns it whole, any other name goes through
        extractor.get_named_target
        '''
        source = batch.target_diff if model_params.predict_dynamics else batch.next_target
        if name == "all": return source
        return extractor.get_named_target(source, names=name)

    def get_omit(self, batch, keep_all=False, keep_invalid=False, use_name=""):
        '''
        gets flags for done or invalid for use_name (self.name when use_name is empty)
        keep_all: returns a boolean vector of ones, one entry per row
        keep_invalid: returns the nonzero() of (1 - done), i.e. only done rows are dropped
        otherwise: returns the nonzero() of the rows that are neither done nor
            invalid, where invalid means (1 - valid[:, index of name]) is nonzero
        note that keep_all gives a mask while the other two give nonzero() output
        '''
        done = batch.done
        # take the first column by indexing rather than squeezing, since we
        # don't want length 1 vectors to get broken
        dones = done if len(done.shape) == 1 else done[:, 0] # otherwise assume that the shape is 2
        if keep_all: return np.ones(dones.shape).astype(bool)
        if keep_invalid: return (1 - dones).nonzero()

        # the name is only needed to look up the valid column
        name = use_name if len(use_name) > 0 else self.name
        invalid = (1 - batch.valid[:, self.extractor.get_index(name)]).astype(bool)
        # assumes numpy arrays (the astype calls suggest so), where adding two
        # boolean arrays is an elementwise or
        dropped = dones.astype(bool) + invalid
        # TODO: could take in valid vector, if given invalidates something
        return (1 - dropped.astype(int)).nonzero()

    def _target_dists(self, batch, params, skip=None):
        '''
        gets the target and the log probabilities of it under the distributions
        params: sequence of parameter tensors, passed in slices to self.forward_dist
        skip: optional, parameter set i is left out when skip[i] == 0
        returns (target, log_probs) with log_probs concatenated on the last dim
        '''
        # names containing '->' use self.target instead of the name itself
        target_name = self.target if '->' in self.name else self.name
        out_state = self.get_output(self.mp, target_name, self.extractor, batch)
        target = pytorch_model.wrap(out_state, cuda=self.iscuda)
        width = target.shape[-1]
        # this peculiar logic is for when there are multiple models that output the same
        # values (i.e. cluster style models): params then hold several target-width
        # slices side by side
        num_param_sets = int(params[0].shape[-1] // width) if params[0].shape[-1] > width else 1
        log_probs = list()
        for i in range(num_param_sets):
            if skip is not None and skip[i] == 0: continue # only add the non-skipped params
            new_params = [p[..., width * i:width * (i + 1)] for p in params]
            dist = self.forward_dist(*new_params)
            log_probs.append(dist.log_prob(target))
        # NOTE(review): if skip removes every parameter set the list is empty here
        # and torch.cat raises -- confirm callers never skip everything
        log_probs = torch.cat(log_probs, dim=-1)
        return target, log_probs

    def _single_index_all(self, name, extractor, params, mask, info):
        '''
        selects column extractor.get_index(name) (second axis) from the first
        two params, from mask and from every entry of info
        info is indexed in place, so the list passed in is modified and is the
        same object that is returned
        '''
        tidx = extractor.get_index(name)
        params = (params[0][:, tidx], params[1][:, tidx])
        mask = mask[:, tidx]
        for i, entry in enumerate(info):
            info[i] = entry[:, tidx]
        return params, mask, info

    def __call__(self, batch, valid, extractor, normalizer, additional=[]):
        '''
        Performs inference, meaning it will populate params, log_probs, mask at the minimum
        expects batch to contain the necessary information for inference
        expects valid to be binary [batch, num_objects]
        extractor is used to get the flattened representation from batch
        normalizer is unnecessary, but used for return
        additional should have the same names as ret_settings in Network.Dists, and the same naming
            convention as other modules
        this base version does nothing and returns None. since it replaces
            nn.Module.__call__, calling an instance does not dispatch to forward()
        the additional default is a shared list, subclasses should not mutate it
        '''
        pass # implemented in subclasses